{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
    "\n",
    "# HugeCTR demo on Movie lens data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "HugeCTR is a recommender-specific framework that is capable of distributed training across multiple GPUs and nodes for click-through-rate (CTR) estimation.\n",
    "HugeCTR is a component of NVIDIA Merlin ([documentation](https://nvidia-merlin.github.io/Merlin/main/README.html) | [GitHub](https://github.com/NVIDIA-Merlin/Merlin)).\n",
     "Merlin is a framework that accelerates the entire pipeline, from data ingestion and training to the deployment of GPU-accelerated recommender systems.\n",
    "\n",
    "### Learning objectives\n",
    "\n",
     "* Training a deep-learning recommender model (DLRM) on the MovieLens 20M [dataset](https://grouplens.org/datasets/movielens/20m/).\n",
     "* Walking through data preprocessing, training a DLRM model with HugeCTR, and then using the trained movie embeddings to answer item similarity queries.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "### Docker containers\n",
    "\n",
     "Start the notebook inside a running NGC Docker container, version 22.05 or later: `nvcr.io/nvidia/merlin/merlin-training:22.05`.\n",
    "The HugeCTR Python interface is installed to the path `/usr/local/hugectr/lib/` and the path is added to the environment variable `PYTHONPATH`.\n",
    "You can use the HugeCTR Python interface within the Docker container without any additional configuration.\n",
    "\n",
    "### Hardware\n",
    "\n",
     "This notebook requires a Pascal, Volta, Turing, Ampere, or newer GPU, such as a P100, V100, T4, or A100.\n",
    "You can view the GPU information with the `nvidia-smi` command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mon Jul 12 06:54:46 2021       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 450.51.06    Driver Version: 450.51.06    CUDA Version: 11.3     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  Tesla V100-PCIE...  On   | 00000000:1A:00.0 Off |                    0 |\n",
      "| N/A   29C    P0    23W / 250W |      0MiB / 32510MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  Tesla V100-PCIE...  On   | 00000000:1B:00.0 Off |                    0 |\n",
      "| N/A   27C    P0    22W / 250W |      0MiB / 32510MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   2  Tesla V100-PCIE...  On   | 00000000:3D:00.0 Off |                    0 |\n",
      "| N/A   26C    P0    23W / 250W |      0MiB / 32510MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   3  Tesla V100-PCIE...  On   | 00000000:3E:00.0 Off |                    0 |\n",
      "| N/A   28C    P0    23W / 250W |      0MiB / 32510MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   4  Tesla V100-PCIE...  On   | 00000000:88:00.0 Off |                    0 |\n",
      "| N/A   25C    P0    24W / 250W |      0MiB / 32510MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   5  Tesla V100-PCIE...  On   | 00000000:89:00.0 Off |                    0 |\n",
      "| N/A   25C    P0    22W / 250W |      0MiB / 32510MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   6  Tesla V100-PCIE...  On   | 00000000:B1:00.0 Off |                    0 |\n",
      "| N/A   26C    P0    23W / 250W |      0MiB / 32510MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   7  Tesla V100-PCIE...  On   | 00000000:B2:00.0 Off |                    0 |\n",
      "| N/A   25C    P0    24W / 250W |      0MiB / 32510MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|  No running processes found                                                 |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data download and preprocessing\n",
    "\n",
    "We first install a few extra utilities for data preprocessing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and installing 'tqdm' package.\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
      "Downloading and installing 'unzip' command\n",
      "Collecting package metadata (current_repodata.json): ...working... done\n",
      "Solving environment: ...working... done\n",
      "\n",
      "## Package Plan ##\n",
      "\n",
      "  environment location: /opt/conda\n",
      "\n",
      "  added / updated specs:\n",
      "    - unzip\n",
      "\n",
      "\n",
      "The following packages will be downloaded:\n",
      "\n",
      "    package                    |            build\n",
      "    ---------------------------|-----------------\n",
      "    unzip-6.0                  |       h7f98852_2         143 KB  conda-forge\n",
      "    ------------------------------------------------------------\n",
      "                                           Total:         143 KB\n",
      "\n",
      "The following NEW packages will be INSTALLED:\n",
      "\n",
      "  unzip              conda-forge/linux-64::unzip-6.0-h7f98852_2\n",
      "\n",
      "\n",
      "Preparing transaction: ...working... done\n",
      "Verifying transaction: ...working... done\n",
      "Executing transaction: ...working... done\n"
     ]
    }
   ],
   "source": [
    "print(\"Downloading and installing 'tqdm' package.\")\n",
    "!pip3 -q install torch tqdm\n",
    "\n",
    "print(\"Downloading and installing 'unzip' command\")\n",
    "!conda install -y -q -c conda-forge unzip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we download and unzip the MovieLens 20M [dataset](https://grouplens.org/datasets/movielens/20m/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and extracting 'Movie Lens 20M' dataset.\n",
      "ml-20m.zip          100%[===================>] 189.50M  46.1MB/s    in 4.5s    \n",
      "Archive:  data/ml-20m.zip\n",
      "   creating: data/ml-20m/\n",
      "  inflating: data/ml-20m/genome-scores.csv  \n",
      "  inflating: data/ml-20m/genome-tags.csv  \n",
      "  inflating: data/ml-20m/links.csv   \n",
      "  inflating: data/ml-20m/movies.csv  \n",
      "  inflating: data/ml-20m/ratings.csv  \n",
      "  inflating: data/ml-20m/README.txt  \n",
      "  inflating: data/ml-20m/tags.csv    \n",
      "ml-20m\tml-20m.zip\n"
     ]
    }
   ],
   "source": [
    "print(\"Downloading and extracting 'Movie Lens 20M' dataset.\")\n",
    "!wget -nc http://files.grouplens.org/datasets/movielens/ml-20m.zip -P data -q --show-progress\n",
    "!unzip -n data/ml-20m.zip -d data\n",
    "!ls ./data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MovieLens data preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import torch\n",
    "import tqdm\n",
    "\n",
    "MIN_RATINGS = 20\n",
    "USER_COLUMN = 'userId'\n",
    "ITEM_COLUMN = 'movieId'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Next, we read the data into a pandas DataFrame and encode `userId` and `movieId` as sequential integer IDs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Filtering out users with less than 20 ratings\n",
      "Mapping original user and item IDs to new sequential IDs\n",
      "Number of users: 138493\n",
      "Number of items: 26744\n"
     ]
    }
   ],
   "source": [
    "df = pd.read_csv('./data/ml-20m/ratings.csv')\n",
    "print(\"Filtering out users with less than {} ratings\".format(MIN_RATINGS))\n",
    "grouped = df.groupby(USER_COLUMN)\n",
    "df = grouped.filter(lambda x: len(x) >= MIN_RATINGS)\n",
    "\n",
    "print(\"Mapping original user and item IDs to new sequential IDs\")\n",
    "df[USER_COLUMN], unique_users = pd.factorize(df[USER_COLUMN])\n",
    "df[ITEM_COLUMN], unique_items = pd.factorize(df[ITEM_COLUMN])\n",
    "\n",
    "nb_users = len(unique_users)\n",
    "nb_items = len(unique_items)\n",
    "\n",
    "print(\"Number of users: %d\\nNumber of items: %d\"%(len(unique_users), len(unique_items)))\n",
    "\n",
    "# Save the mapping to do the inference later on\n",
    "import pickle\n",
    "with open('./mappings.pickle', 'wb') as handle:\n",
    "    pickle.dump({\"users\": unique_users, \"items\": unique_items}, handle, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
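  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `pd.factorize` returns the unique values in encoding order, the saved mapping can later translate an original MovieLens ID to its internal sequential ID and back. A minimal sketch of the round trip, using a small made-up ID list rather than the real data:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# The internal ID of an original value is its position in the uniques array.\n",
    "codes, unique_items = pd.factorize(pd.Series([31, 1029, 31, 112]))\n",
    "\n",
    "# Round trip: original movieId -> internal ID -> original movieId\n",
    "internal_id = pd.Index(unique_items).get_loc(1029)  # -> 1\n",
    "original_id = unique_items[internal_id]             # -> 1029\n",
    "```"
   ]
  },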
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Next, we split the data into a train set and a test set.\n",
     "Each user's most recently rated movie is held out for the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Need to sort before popping to get the last item\n",
    "df.sort_values(by='timestamp', inplace=True)\n",
    "    \n",
    "# clean up data\n",
    "del df['rating'], df['timestamp']\n",
    "df = df.drop_duplicates() # assuming it keeps order\n",
    "\n",
    "# now we have filtered and sorted by time data, we can split test data out\n",
    "grouped_sorted = df.groupby(USER_COLUMN, group_keys=False)\n",
    "test_data = grouped_sorted.tail(1).sort_values(by=USER_COLUMN)\n",
    "\n",
    "# need to pop for each group\n",
    "train_data = grouped_sorted.apply(lambda x: x.iloc[:-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0</td>\n",
       "      <td>19</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>86</th>\n",
       "      <td>0</td>\n",
       "      <td>86</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>61</th>\n",
       "      <td>0</td>\n",
       "      <td>61</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>0</td>\n",
       "      <td>23</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    userId  movieId  target\n",
       "20       0       20       1\n",
       "19       0       19       1\n",
       "86       0       86       1\n",
       "61       0       61       1\n",
       "23       0       23       1"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "train_data['target'] = 1\n",
     "test_data['target'] = 1\n",
    "train_data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Because the MovieLens data contains only positive examples, we first define a utility function to generate negative samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _TestNegSampler:\n",
    "    def __init__(self, train_ratings, nb_users, nb_items, nb_neg):\n",
    "        self.nb_neg = nb_neg\n",
    "        self.nb_users = nb_users \n",
    "        self.nb_items = nb_items \n",
    "\n",
     "        # compute unique IDs to build a hash set for fast membership lookup\n",
    "        ids = (train_ratings[:, 0] * self.nb_items) + train_ratings[:, 1]\n",
    "        self.set = set(ids)\n",
    "\n",
    "    def generate(self, batch_size=128*1024):\n",
    "        users = torch.arange(0, self.nb_users).reshape([1, -1]).repeat([self.nb_neg, 1]).transpose(0, 1).reshape(-1)\n",
    "\n",
    "        items = [-1] * len(users)\n",
    "\n",
    "        random_items = torch.LongTensor(batch_size).random_(0, self.nb_items).tolist()\n",
    "        print('Generating validation negatives...')\n",
    "        for idx, u in enumerate(tqdm.tqdm(users.tolist())):\n",
    "            if not random_items:\n",
    "                random_items = torch.LongTensor(batch_size).random_(0, self.nb_items).tolist()\n",
    "            j = random_items.pop()\n",
    "            while u * self.nb_items + j in self.set:\n",
    "                if not random_items:\n",
    "                    random_items = torch.LongTensor(batch_size).random_(0, self.nb_items).tolist()\n",
    "                j = random_items.pop()\n",
    "\n",
    "            items[idx] = j\n",
    "        items = torch.LongTensor(items)\n",
    "        return items"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we generate the negative samples for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating validation negatives...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 69246500/69246500 [00:44<00:00, 1566380.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating validation negatives...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13849300/13849300 [00:08<00:00, 1594800.54it/s]\n"
     ]
    }
   ],
   "source": [
    "sampler = _TestNegSampler(df.values, nb_users, nb_items, 500)  # using 500 negative samples\n",
    "train_negs = sampler.generate()\n",
    "train_negs = train_negs.reshape(-1, 500)\n",
    "\n",
    "sampler = _TestNegSampler(df.values, nb_users, nb_items, 100)  # using 100 negative samples\n",
    "test_negs = sampler.generate()\n",
    "test_negs = test_negs.reshape(-1, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 138493/138493 [04:07<00:00, 558.71it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 138493/138493 [00:49<00:00, 2819.57it/s]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# generating negative samples for training\n",
    "train_data_neg = np.zeros((train_negs.shape[0]*train_negs.shape[1],3), dtype=int)\n",
    "idx = 0\n",
    "for i in tqdm.tqdm(range(train_negs.shape[0])):\n",
    "    for j in range(train_negs.shape[1]):\n",
    "        train_data_neg[idx, 0] = i # user ID\n",
    "        train_data_neg[idx, 1] = train_negs[i, j] # negative item ID\n",
    "        idx += 1\n",
    "    \n",
    "# generating negative samples for testing\n",
    "test_data_neg = np.zeros((test_negs.shape[0]*test_negs.shape[1],3), dtype=int)\n",
    "idx = 0\n",
    "for i in tqdm.tqdm(range(test_negs.shape[0])):\n",
    "    for j in range(test_negs.shape[1]):\n",
    "        test_data_neg[idx, 0] = i\n",
    "        test_data_neg[idx, 1] = test_negs[i, j]\n",
    "        idx += 1"
   ]
  },
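  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Python loops above are easy to follow but slow for tens of millions of rows. A vectorized NumPy construction produces the same (user, item, label) row layout; this sketch uses a tiny made-up negatives matrix in place of `train_negs`:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Stand-in for train_negs: 2 users with 2 negative item IDs each\n",
    "negs = np.array([[5, 7], [3, 9]])\n",
    "num_users, num_negs = negs.shape\n",
    "\n",
    "data_neg = np.zeros((num_users * num_negs, 3), dtype=int)\n",
    "data_neg[:, 0] = np.repeat(np.arange(num_users), num_negs)  # user IDs\n",
    "data_neg[:, 1] = negs.reshape(-1)                           # negative item IDs\n",
    "# column 2 stays 0: the negative label\n",
    "```"
   ]
  },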
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
     "train_data_np = np.concatenate([train_data_neg, train_data.values])\n",
     "np.random.shuffle(train_data_np)\n",
     "\n",
     "test_data_np = np.concatenate([test_data_neg, test_data.values])\n",
     "np.random.shuffle(test_data_np)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
     "# HugeCTR expects user IDs and item IDs to occupy disjoint ranges, so we use 0 -> nb_users\n",
     "# for user IDs and nb_users -> nb_users+nb_items for item IDs.\n",
     "train_data_np[:,1] += nb_users\n",
     "test_data_np[:,1] += nb_users"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "165236"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.max(train_data_np[:,1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Write HugeCTR data files\n",
    "\n",
     "After preprocessing, we write the data to disk using the HugeCTR [Norm](https://nvidia-merlin.github.io/HugeCTR/master/api/python_interface.html#norm) dataset format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ctypes import c_longlong as ll\n",
    "from ctypes import c_uint\n",
    "from ctypes import c_float\n",
    "from ctypes import c_int\n",
    "\n",
    "def write_hugeCTR_data(huge_ctr_data, filename='huge_ctr_data.dat'):\n",
    "    print(\"Writing %d samples\"%huge_ctr_data.shape[0])\n",
    "    with open(filename, 'wb') as f:\n",
    "        #write header\n",
    "        f.write(ll(0)) # 0: no error check; 1: check_num\n",
    "        f.write(ll(huge_ctr_data.shape[0])) # the number of samples in this data file\n",
    "        f.write(ll(1)) # dimension of label\n",
    "        f.write(ll(1)) # dimension of dense feature\n",
    "        f.write(ll(2)) # long long slot_num\n",
    "        for _ in range(3): f.write(ll(0)) # reserved for future use\n",
    "\n",
    "        for i in tqdm.tqdm(range(huge_ctr_data.shape[0])):\n",
    "            f.write(c_float(huge_ctr_data[i,2])) # float label[label_dim];\n",
    "            f.write(c_float(0)) # dummy dense feature\n",
    "            f.write(c_int(1)) # slot 1 nnz: user ID\n",
    "            f.write(c_uint(huge_ctr_data[i,0]))\n",
    "            f.write(c_int(1)) # slot 2 nnz: item ID\n",
    "            f.write(c_uint(huge_ctr_data[i,1]))"
   ]
  },
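  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the fixed-size header can be read back with Python's `struct` module. This sketch mirrors the field order written by `write_hugeCTR_data` above: eight native-endian 64-bit integers, the last three reserved:\n",
    "\n",
    "```python\n",
    "import struct\n",
    "\n",
    "def read_norm_header(filename):\n",
    "    # 8 int64 fields: error check, num samples, label dim,\n",
    "    # dense dim, slot num, and 3 reserved values\n",
    "    with open(filename, 'rb') as f:\n",
    "        fields = struct.unpack('8q', f.read(64))\n",
    "    return {'error_check': fields[0], 'num_samples': fields[1],\n",
    "            'label_dim': fields[2], 'dense_dim': fields[3],\n",
    "            'slot_num': fields[4]}\n",
    "```\n",
    "\n",
    "For a file written below, `read_norm_header('./data/hugeCTR/train_huge_ctr_data_0.dat')['num_samples']` should match the sample count printed while writing."
   ]
  },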
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Train data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_filelist(filelist_name, num_files, filename_prefix):\n",
    "    with open(filelist_name, 'wt') as f:\n",
     "        f.write('{0}\\n'.format(num_files))\n",
    "        for i in range(num_files):\n",
    "            f.write('{0}_{1}.dat\\n'.format(filename_prefix, i))"
   ]
  },
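  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The filelist written by `generate_filelist` is plain text: the first line holds the number of data files, followed by one data-file path per line. The snippet below rebuilds the same content inline for an illustrative two-file prefix (the prefix is made up):\n",
    "\n",
    "```python\n",
    "num_files, prefix = 2, './data/prefix'\n",
    "lines = [str(num_files)] + ['{0}_{1}.dat'.format(prefix, i) for i in range(num_files)]\n",
    "print('\\n'.join(lines))\n",
    "# 2\n",
    "# ./data/prefix_0.dat\n",
    "# ./data/prefix_1.dat\n",
    "```"
   ]
  },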
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 8910827 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8910827/8910827 [00:17<00:00, 513695.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 8910827 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8910827/8910827 [00:16<00:00, 526049.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 8910827 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8910827/8910827 [00:16<00:00, 525218.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 8910827 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8910827/8910827 [00:16<00:00, 528084.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 8910827 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8910827/8910827 [00:16<00:00, 525638.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 8910827 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8910827/8910827 [00:16<00:00, 528931.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 8910827 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8910827/8910827 [00:16<00:00, 531191.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 8910827 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8910827/8910827 [00:16<00:00, 532537.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 8910827 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8910827/8910827 [00:16<00:00, 528103.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 8910827 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8910827/8910827 [00:17<00:00, 522249.44it/s]\n"
     ]
    }
   ],
   "source": [
    "!rm -rf ./data/hugeCTR\n",
    "!mkdir ./data/hugeCTR\n",
    "\n",
    "for i, data_arr in enumerate(np.array_split(train_data_np,10)):\n",
    "    write_hugeCTR_data(data_arr, filename='./data/hugeCTR/train_huge_ctr_data_%d.dat'%i)\n",
    "\n",
    "generate_filelist('./data/hugeCTR/train_filelist.txt', 10, './data/hugeCTR/train_huge_ctr_data')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 1398780 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1398780/1398780 [00:02<00:00, 510667.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 1398780 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1398780/1398780 [00:02<00:00, 523734.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 1398780 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1398780/1398780 [00:02<00:00, 512399.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 1398779 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1398779/1398779 [00:02<00:00, 519540.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 1398779 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1398779/1398779 [00:02<00:00, 522322.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 1398779 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1398779/1398779 [00:02<00:00, 525051.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 1398779 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1398779/1398779 [00:02<00:00, 527603.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 1398779 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1398779/1398779 [00:02<00:00, 521668.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 1398779 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1398779/1398779 [00:02<00:00, 517335.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing 1398779 samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1398779/1398779 [00:02<00:00, 522761.79it/s]\n"
     ]
    }
   ],
   "source": [
    "# Split the test set into 10 shards and write each in HugeCTR's data format,\n",
    "# then record the shard paths in a file list for the data reader.\n",
    "for i, data_arr in enumerate(np.array_split(test_data_np, 10)):\n",
    "    write_hugeCTR_data(data_arr, filename='./data/hugeCTR/test_huge_ctr_data_%d.dat' % i)\n",
    "\n",
    "generate_filelist('./data/hugeCTR/test_filelist.txt', 10, './data/hugeCTR/test_huge_ctr_data')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HugeCTR DLRM training\n",
    "\n",
    "In this section, we train a DLRM network on the augmented MovieLens data. First, we write the training script to a Python file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting hugectr_dlrm_movielens.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile hugectr_dlrm_movielens.py\n",
    "import hugectr\n",
    "from mpi4py import MPI\n",
    "solver = hugectr.CreateSolver(max_eval_batches = 1000,\n",
    "                              batchsize_eval = 65536,\n",
    "                              batchsize = 65536,\n",
    "                              lr = 0.1,\n",
    "                              warmup_steps = 1000,\n",
    "                              decay_start = 10000,\n",
    "                              decay_steps = 40000,\n",
    "                              decay_power = 2.0,\n",
    "                              end_lr = 1e-5,\n",
    "                              vvgpu = [[0]],\n",
    "                              repeat_dataset = True,\n",
    "                              use_mixed_precision = True,\n",
    "                              scaler = 1024)\n",
    "reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,\n",
    "                                  source = [\"./data/hugeCTR/train_filelist.txt\"],\n",
    "                                  eval_source = \"./data/hugeCTR/test_filelist.txt\",\n",
    "                                  check_type = hugectr.Check_t.Non)\n",
    "optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.SGD,\n",
    "                                    update_type = hugectr.Update_t.Local,\n",
    "                                    atomic_update = True)\n",
    "model = hugectr.Model(solver, reader, optimizer)\n",
    "model.add(hugectr.Input(label_dim = 1, label_name = \"label\",\n",
    "                        dense_dim = 1, dense_name = \"dense\",\n",
    "                        data_reader_sparse_param_array = \n",
    "                        [hugectr.DataReaderSparseParam(\"data1\", 1, True, 2)]))\n",
    "model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.LocalizedSlotSparseEmbeddingHash, \n",
    "                            workspace_size_per_gpu_in_mb = 41,\n",
    "                            embedding_vec_size = 64,\n",
    "                            combiner = \"sum\",\n",
    "                            sparse_embedding_name = \"sparse_embedding1\",\n",
    "                            bottom_name = \"data1\",\n",
    "                            optimizer = optimizer))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.FusedInnerProduct,\n",
    "                            bottom_names = [\"dense\"],\n",
    "                            top_names = [\"fc1\"],\n",
    "                            num_output=64))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.FusedInnerProduct,\n",
    "                            bottom_names = [\"fc1\"],\n",
    "                            top_names = [\"fc2\"],\n",
    "                            num_output=128))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.FusedInnerProduct,\n",
    "                            bottom_names = [\"fc2\"],\n",
    "                            top_names = [\"fc3\"],\n",
    "                            num_output=64))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Interaction,\n",
    "                            bottom_names = [\"fc3\",\"sparse_embedding1\"],\n",
    "                            top_names = [\"interaction1\"]))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.FusedInnerProduct,\n",
    "                            bottom_names = [\"interaction1\"],\n",
    "                            top_names = [\"fc4\"],\n",
    "                            num_output=1024))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.FusedInnerProduct,\n",
    "                            bottom_names = [\"fc4\"],\n",
    "                            top_names = [\"fc5\"],\n",
    "                            num_output=1024))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.FusedInnerProduct,\n",
    "                            bottom_names = [\"fc5\"],\n",
    "                            top_names = [\"fc6\"],\n",
    "                            num_output=512))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.FusedInnerProduct,\n",
    "                            bottom_names = [\"fc6\"],\n",
    "                            top_names = [\"fc7\"],\n",
    "                            num_output=256))                                                  \n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"fc7\"],\n",
    "                            top_names = [\"fc8\"],\n",
    "                            num_output=1))                                                                                           \n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,\n",
    "                            bottom_names = [\"fc8\", \"label\"],\n",
    "                            top_names = [\"loss\"]))\n",
    "model.compile()\n",
    "model.summary()\n",
    "model.fit(max_iter = 50000, display = 1000, eval_interval = 3000, snapshot = 3000, snapshot_prefix = \"./hugeCTR_saved_model_DLRM/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf ./hugeCTR_saved_model_DLRM/\n",
    "!mkdir ./hugeCTR_saved_model_DLRM/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================Model Init=====================================================\n",
      "[12d06h55m13s][HUGECTR][INFO]: Global seed is 2552343530\n",
      "[12d06h55m15s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.\n",
      "Device 0: Tesla V100-PCIE-32GB\n",
      "[12d06h55m15s][HUGECTR][INFO]: num of DataReader workers: 12\n",
      "[12d06h55m15s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=167936\n",
      "[12d06h55m15s][HUGECTR][INFO]: All2All Warmup Start\n",
      "[12d06h55m15s][HUGECTR][INFO]: All2All Warmup End\n",
      "===================================================Model Compile===================================================\n",
      "[12d06h56m10s][HUGECTR][INFO]: gpu0 start to init embedding\n",
      "[12d06h56m10s][HUGECTR][INFO]: gpu0 init embedding done\n",
      "===================================================Model Summary===================================================\n",
      "Label                                   Dense                         Sparse                        \n",
      "label                                   dense                          data1                         \n",
      "(None, 1)                               (None, 1)                               \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "Layer Type                              Input Name                    Output Name                   Output Shape                  \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "LocalizedSlotSparseEmbeddingHash        data1                         sparse_embedding1             (None, 2, 64)                 \n",
      "FusedInnerProduct                       dense                         fc1                           (None, 64)                    \n",
      "FusedInnerProduct                       fc1                           fc2                           (None, 128)                   \n",
      "FusedInnerProduct                       fc2                           fc3                           (None, 64)                    \n",
      "Interaction                             fc3,sparse_embedding1         interaction1                  (None, 68)                    \n",
      "FusedInnerProduct                       interaction1                  fc4                           (None, 1024)                  \n",
      "FusedInnerProduct                       fc4                           fc5                           (None, 1024)                  \n",
      "FusedInnerProduct                       fc5                           fc6                           (None, 512)                   \n",
      "FusedInnerProduct                       fc6                           fc7                           (None, 256)                   \n",
      "InnerProduct                            fc7                           fc8                           (None, 1)                     \n",
      "BinaryCrossEntropyLoss                  fc8,label                     loss                                                        \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "=====================================================Model Fit=====================================================\n",
      "[12d60h56m10s][HUGECTR][INFO]: Use non-epoch mode with number of iterations: 50000\n",
      "[12d60h56m10s][HUGECTR][INFO]: Training batchsize: 65536, evaluation batchsize: 65536\n",
      "[12d60h56m10s][HUGECTR][INFO]: Evaluation interval: 3000, snapshot interval: 3000\n",
      "[12d60h56m10s][HUGECTR][INFO]: Sparse embedding trainable: 1, dense network trainable: 1\n",
      "[12d60h56m10s][HUGECTR][INFO]: Use mixed precision: 1, scaler: 1024.000000, use cuda graph: 1\n",
      "[12d60h56m10s][HUGECTR][INFO]: lr: 0.100000, warmup_steps: 1000, decay_start: 10000, decay_steps: 40000, decay_power: 2.000000, end_lr: 0.000010\n",
      "[12d60h56m10s][HUGECTR][INFO]: Training source file: ./data/hugeCTR/train_filelist.txt\n",
      "[12d60h56m10s][HUGECTR][INFO]: Evaluation source file: ./data/hugeCTR/test_filelist.txt\n",
      "[12d60h56m25s][HUGECTR][INFO]: Iter: 1000 Time(1000 iters): 14.895018s Loss: 0.534868 lr:0.100000\n",
      "[12d60h56m40s][HUGECTR][INFO]: Iter: 2000 Time(1000 iters): 14.917098s Loss: 0.526272 lr:0.100000\n",
      "[12d60h56m55s][HUGECTR][INFO]: Iter: 3000 Time(1000 iters): 14.945527s Loss: 0.504054 lr:0.100000\n",
      "[12d60h57m10s][HUGECTR][INFO]: Evaluation, AUC: 0.698215\n",
      "[12d60h57m10s][HUGECTR][INFO]: Eval Time for 1000 iters: 5.962128s\n",
      "[12d60h57m10s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d60h57m10s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d60h57m10s][HUGECTR][INFO]: Done\n",
      "[12d60h57m10s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d60h57m10s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d60h57m10s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d60h57m10s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d60h57m10s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d60h57m16s][HUGECTR][INFO]: Iter: 4000 Time(1000 iters): 21.357401s Loss: 0.286658 lr:0.100000\n",
      "[12d60h57m31s][HUGECTR][INFO]: Iter: 5000 Time(1000 iters): 15.037847s Loss: 0.249509 lr:0.100000\n",
      "[12d60h57m46s][HUGECTR][INFO]: Iter: 6000 Time(1000 iters): 15.048834s Loss: 0.239949 lr:0.100000\n",
      "[12d60h57m52s][HUGECTR][INFO]: Evaluation, AUC: 0.928999\n",
      "[12d60h57m52s][HUGECTR][INFO]: Eval Time for 1000 iters: 5.993647s\n",
      "[12d60h57m52s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d60h57m52s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d60h57m52s][HUGECTR][INFO]: Done\n",
      "[12d60h57m52s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d60h57m52s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d60h57m52s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d60h57m52s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d60h57m52s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d60h58m80s][HUGECTR][INFO]: Iter: 7000 Time(1000 iters): 21.364920s Loss: 0.242271 lr:0.100000\n",
      "[12d60h58m23s][HUGECTR][INFO]: Iter: 8000 Time(1000 iters): 15.036863s Loss: 0.236050 lr:0.100000\n",
      "[12d60h58m38s][HUGECTR][INFO]: Iter: 9000 Time(1000 iters): 15.042685s Loss: 0.235748 lr:0.100000\n",
      "[12d60h58m44s][HUGECTR][INFO]: Evaluation, AUC: 0.937590\n",
      "[12d60h58m44s][HUGECTR][INFO]: Eval Time for 1000 iters: 5.990306s\n",
      "[12d60h58m44s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d60h58m44s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d60h58m44s][HUGECTR][INFO]: Done\n",
      "[12d60h58m44s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d60h58m44s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d60h58m44s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d60h58m44s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d60h58m44s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d60h58m59s][HUGECTR][INFO]: Iter: 10000 Time(1000 iters): 21.408894s Loss: 0.233947 lr:0.099995\n",
      "[12d60h59m14s][HUGECTR][INFO]: Iter: 11000 Time(1000 iters): 15.050379s Loss: 0.231177 lr:0.095058\n",
      "[12d60h59m29s][HUGECTR][INFO]: Iter: 12000 Time(1000 iters): 15.047381s Loss: 0.230662 lr:0.090245\n",
      "[12d60h59m35s][HUGECTR][INFO]: Evaluation, AUC: 0.940782\n",
      "[12d60h59m35s][HUGECTR][INFO]: Eval Time for 1000 iters: 5.990065s\n",
      "[12d60h59m35s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d60h59m35s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d60h59m35s][HUGECTR][INFO]: Done\n",
      "[12d60h59m36s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d60h59m36s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d60h59m36s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d60h59m36s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d60h59m36s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d60h59m51s][HUGECTR][INFO]: Iter: 13000 Time(1000 iters): 21.492720s Loss: 0.229246 lr:0.085558\n",
      "[12d70h00m60s][HUGECTR][INFO]: Iter: 14000 Time(1000 iters): 15.051535s Loss: 0.227302 lr:0.080996\n",
      "[12d70h00m21s][HUGECTR][INFO]: Iter: 15000 Time(1000 iters): 15.062830s Loss: 0.22.057 lr:0.076558\n",
      "[12d70h00m27s][HUGECTR][INFO]: Evaluation, AUC: 0.941291\n",
      "[12d70h00m27s][HUGECTR][INFO]: Eval Time for 1000 iters: 6.004500s\n",
      "[12d70h00m27s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h00m27s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h00m27s][HUGECTR][INFO]: Done\n",
      "[12d70h00m27s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h00m27s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h00m27s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h00m27s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h00m27s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h00m42s][HUGECTR][INFO]: Iter: 16000 Time(1000 iters): 21.480675s Loss: 0.220782 lr:0.072246\n",
      "[12d70h00m57s][HUGECTR][INFO]: Iter: 17000 Time(1000 iters): 15.057642s Loss: 0.214406 lr:0.068058\n",
      "[12d70h10m12s][HUGECTR][INFO]: Iter: 18000 Time(1000 iters): 15.068874s Loss: 0.211810 lr:0.063996\n",
      "[12d70h10m18s][HUGECTR][INFO]: Evaluation, AUC: 0.943403\n",
      "[12d70h10m18s][HUGECTR][INFO]: Eval Time for 1000 iters: 5.994943s\n",
      "[12d70h10m18s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h10m18s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h10m18s][HUGECTR][INFO]: Done\n",
      "[12d70h10m19s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h10m19s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h10m19s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h10m19s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h10m19s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h10m34s][HUGECTR][INFO]: Iter: 19000 Time(1000 iters): 21.541020s Loss: 0.208731 lr:0.060059\n",
      "[12d70h10m49s][HUGECTR][INFO]: Iter: 20000 Time(1000 iters): 15.051771s Loss: 0.206068 lr:0.056246\n",
      "[12d70h20m40s][HUGECTR][INFO]: Iter: 21000 Time(1000 iters): 15.067925s Loss: 0.205040 lr:0.052559\n",
      "[12d70h20m10s][HUGECTR][INFO]: Evaluation, AUC: 0.945471\n",
      "[12d70h20m10s][HUGECTR][INFO]: Eval Time for 1000 iters: 6.037830s\n",
      "[12d70h20m10s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h20m10s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h20m10s][HUGECTR][INFO]: Done\n",
      "[12d70h20m11s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h20m11s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h20m11s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h20m11s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h20m11s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h20m26s][HUGECTR][INFO]: Iter: 22000 Time(1000 iters): 22.271977s Loss: 0.199577 lr:0.048997\n",
      "[12d70h20m41s][HUGECTR][INFO]: Iter: 23000 Time(1000 iters): 15.047657s Loss: 0.194625 lr:0.045559\n",
      "[12d70h20m56s][HUGECTR][INFO]: Iter: 24000 Time(1000 iters): 15.054897s Loss: 0.197816 lr:0.042247\n",
      "[12d70h30m20s][HUGECTR][INFO]: Evaluation, AUC: 0.946273\n",
      "[12d70h30m20s][HUGECTR][INFO]: Eval Time for 1000 iters: 6.023635s\n",
      "[12d70h30m20s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h30m20s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h30m20s][HUGECTR][INFO]: Done\n",
      "[12d70h30m40s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h30m40s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h30m40s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h30m40s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h30m40s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h30m19s][HUGECTR][INFO]: Iter: 25000 Time(1000 iters): 22.792095s Loss: 0.195353 lr:0.039059\n",
      "[12d70h30m34s][HUGECTR][INFO]: Iter: 26000 Time(1000 iters): 15.069135s Loss: 0.194946 lr:0.035997\n",
      "[12d70h30m49s][HUGECTR][INFO]: Iter: 27000 Time(1000 iters): 15.044690s Loss: 0.196138 lr:0.033060\n",
      "[12d70h30m55s][HUGECTR][INFO]: Evaluation, AUC: 0.946479\n",
      "[12d70h30m55s][HUGECTR][INFO]: Eval Time for 1000 iters: 6.036560s\n",
      "[12d70h30m55s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h30m55s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h30m55s][HUGECTR][INFO]: Done\n",
      "[12d70h30m56s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h30m56s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h30m56s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h30m56s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h30m56s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h40m11s][HUGECTR][INFO]: Iter: 28000 Time(1000 iters): 21.477826s Loss: 0.196544 lr:0.030247\n",
      "[12d70h40m26s][HUGECTR][INFO]: Iter: 29000 Time(1000 iters): 15.047754s Loss: 0.192916 lr:0.027560\n",
      "[12d70h40m41s][HUGECTR][INFO]: Iter: 30000 Time(1000 iters): 15.076476s Loss: 0.193249 lr:0.024998\n",
      "[12d70h40m47s][HUGECTR][INFO]: Evaluation, AUC: 0.946866\n",
      "[12d70h40m47s][HUGECTR][INFO]: Eval Time for 1000 iters: 6.019900s\n",
      "[12d70h40m47s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h40m47s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h40m47s][HUGECTR][INFO]: Done\n",
      "[12d70h40m47s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h40m47s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h40m47s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h40m47s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h40m47s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h50m20s][HUGECTR][INFO]: Iter: 31000 Time(1000 iters): 21.420334s Loss: 0.191549 lr:0.022560\n",
      "[12d70h50m17s][HUGECTR][INFO]: Iter: 32000 Time(1000 iters): 15.056377s Loss: 0.192337 lr:0.020248\n",
      "[12d70h50m32s][HUGECTR][INFO]: Iter: 33000 Time(1000 iters): 15.049432s Loss: 0.190889 lr:0.018060\n",
      "[12d70h50m38s][HUGECTR][INFO]: Evaluation, AUC: 0.947067\n",
      "[12d70h50m38s][HUGECTR][INFO]: Eval Time for 1000 iters: 6.038870s\n",
      "[12d70h50m39s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h50m39s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h50m39s][HUGECTR][INFO]: Done\n",
      "[12d70h50m39s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h50m39s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h50m39s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h50m39s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h50m39s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h50m54s][HUGECTR][INFO]: Iter: 34000 Time(1000 iters): 21.957504s Loss: 0.190454 lr:0.015998\n",
      "[12d70h60m90s][HUGECTR][INFO]: Iter: 35000 Time(1000 iters): 15.051283s Loss: 0.188163 lr:0.014061\n",
      "[12d70h60m24s][HUGECTR][INFO]: Iter: 36000 Time(1000 iters): 15.057633s Loss: 0.192510 lr:0.012248\n",
      "[12d70h60m31s][HUGECTR][INFO]: Evaluation, AUC: 0.947169\n",
      "[12d70h60m31s][HUGECTR][INFO]: Eval Time for 1000 iters: 6.039515s\n",
      "[12d70h60m31s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h60m31s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h60m31s][HUGECTR][INFO]: Done\n",
      "[12d70h60m31s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h60m31s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h60m31s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h60m31s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h60m31s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h60m46s][HUGECTR][INFO]: Iter: 37000 Time(1000 iters): 21.491865s Loss: 0.190069 lr:0.010561\n",
      "[12d70h70m10s][HUGECTR][INFO]: Iter: 38000 Time(1000 iters): 15.070367s Loss: 0.192338 lr:0.008999\n",
      "[12d70h70m16s][HUGECTR][INFO]: Iter: 39000 Time(1000 iters): 15.056408s Loss: 0.189535 lr:0.007561\n",
      "[12d70h70m22s][HUGECTR][INFO]: Evaluation, AUC: 0.947164\n",
      "[12d70h70m22s][HUGECTR][INFO]: Eval Time for 1000 iters: 5.993091s\n",
      "[12d70h70m22s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h70m22s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h70m22s][HUGECTR][INFO]: Done\n",
      "[12d70h70m22s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h70m22s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h70m22s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h70m22s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h70m22s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h70m38s][HUGECTR][INFO]: Iter: 40000 Time(1000 iters): 21.440558s Loss: 0.188189 lr:0.006249\n",
      "[12d70h70m53s][HUGECTR][INFO]: Iter: 41000 Time(1000 iters): 15.057426s Loss: 0.187295 lr:0.005061\n",
      "[12d70h80m80s][HUGECTR][INFO]: Iter: 42000 Time(1000 iters): 15.075448s Loss: 0.188529 lr:0.003999\n",
      "[12d70h80m14s][HUGECTR][INFO]: Evaluation, AUC: 0.947195\n",
      "[12d70h80m14s][HUGECTR][INFO]: Eval Time for 1000 iters: 6.011289s\n",
      "[12d70h80m14s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h80m14s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h80m14s][HUGECTR][INFO]: Done\n",
      "[12d70h80m14s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h80m14s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h80m14s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h80m14s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h80m14s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h80m29s][HUGECTR][INFO]: Iter: 43000 Time(1000 iters): 21.454947s Loss: 0.188799 lr:0.003062\n",
      "[12d70h80m44s][HUGECTR][INFO]: Iter: 44000 Time(1000 iters): 15.055168s Loss: 0.190610 lr:0.002249\n",
      "[12d70h80m59s][HUGECTR][INFO]: Iter: 45000 Time(1000 iters): 15.067865s Loss: 0.191055 lr:0.001562\n",
      "[12d70h90m50s][HUGECTR][INFO]: Evaluation, AUC: 0.947241\n",
      "[12d70h90m50s][HUGECTR][INFO]: Eval Time for 1000 iters: 6.046591s\n",
      "[12d70h90m50s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h90m50s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h90m50s][HUGECTR][INFO]: Done\n",
      "[12d70h90m60s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h90m60s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h90m60s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h90m60s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h90m60s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h90m21s][HUGECTR][INFO]: Iter: 46000 Time(1000 iters): 21.669764s Loss: 0.187626 lr:0.001000\n",
      "[12d70h90m36s][HUGECTR][INFO]: Iter: 47000 Time(1000 iters): 15.044369s Loss: 0.188257 lr:0.000562\n",
      "[12d70h90m51s][HUGECTR][INFO]: Iter: 48000 Time(1000 iters): 15.050518s Loss: 0.190723 lr:0.000250\n",
      "[12d70h90m57s][HUGECTR][INFO]: Evaluation, AUC: 0.947264\n",
      "[12d70h90m57s][HUGECTR][INFO]: Eval Time for 1000 iters: 6.008485s\n",
      "[12d70h90m57s][HUGECTR][INFO]: Rank0: Dump hash table from GPU0\n",
      "[12d70h90m57s][HUGECTR][INFO]: Rank0: Write hash table <key,value> pairs to file\n",
      "[12d70h90m57s][HUGECTR][INFO]: Done\n",
      "[12d70h90m58s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[12d70h90m58s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[12d70h90m58s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[12d70h90m58s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[12d70h90m58s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[12d70h10m13s][HUGECTR][INFO]: Iter: 49000 Time(1000 iters): 21.945730s Loss: 0.188774 lr:0.000062\n"
     ]
    }
   ],
   "source": [
    "!CUDA_VISIBLE_DEVICES=0 python3 hugectr_dlrm_movielens.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Answering item similarity queries with DLRM embeddings\n",
    "\n",
    "In this section, we demonstrate how the output of HugeCTR training can be used for simple inference tasks. Specifically, we show that the learned movie embeddings support item-to-item similarity queries. Such queries can serve as an efficient candidate generator, producing a small set of candidates for re-ranking by a deep learning model.\n",
    "\n",
    "First, we read the embedding tables and extract the movie embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import struct\n",
    "import pickle\n",
    "import numpy as np\n",
    "\n",
    "key_type = 'I64'\n",
    "key_type_map = {\"I32\": [\"I\", 4], \"I64\": [\"q\", 8]}\n",
    "\n",
    "embedding_vec_size = 64\n",
    "\n",
    "HUGE_CTR_VERSION = 2.21 # set HugeCTR version here, 2.2 for v2.2, 2.21 for v2.21\n",
    "\n",
    "if HUGE_CTR_VERSION <= 2.2:\n",
    "    each_key_size = key_type_map[key_type][1] + key_type_map[key_type][1] + 4 * embedding_vec_size\n",
    "else:\n",
    "    each_key_size = key_type_map[key_type][1] + 8 + 4 * embedding_vec_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_table = {}\n",
    "\n",
    "# Read the sparse model back in: keys and embedding vectors are stored in\n",
    "# separate binary files, one fixed-size record per embedding entry.\n",
    "with open(\"./hugeCTR_saved_model_DLRM/0_sparse_9000.model\" + \"/key\", 'rb') as key_file, \\\n",
    "     open(\"./hugeCTR_saved_model_DLRM/0_sparse_9000.model\" + \"/emb_vector\", 'rb') as vec_file:\n",
    "    try:\n",
    "        while True:\n",
    "            key_buffer = key_file.read(key_type_map[key_type][1])\n",
    "            vec_buffer = vec_file.read(4 * embedding_vec_size)\n",
    "            if len(key_buffer) == 0 or len(vec_buffer) == 0:\n",
    "                break\n",
    "            key = struct.unpack(key_type_map[key_type][0], key_buffer)[0]\n",
    "            values = struct.unpack(str(embedding_vec_size) + \"f\", vec_buffer)\n",
    "\n",
    "            embedding_table[key] = values\n",
    "\n",
    "    except BaseException as error:\n",
    "        print(error)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a dense matrix of movie embeddings: row i holds the embedding of the\n",
    "# movie with mapped key i (assuming movie keys occupy 0..26743 after mapping);\n",
    "# keys outside that range are skipped.\n",
    "item_embedding = np.zeros((26744, embedding_vec_size), dtype='float')\n",
    "for key, values in embedding_table.items():\n",
    "    if key < item_embedding.shape[0]:\n",
    "        item_embedding[key] = values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Answer nearest neighbor queries\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import cdist\n",
    "\n",
    "def find_similar_movies(nn_movie_id, item_embedding, k=10, metric=\"euclidean\"):\n",
    "    # Find the top-k most similar items under the given distance metric\n",
    "    # (e.g. \"cosine\" or \"euclidean\"). 1 - distance turns the distance into a\n",
    "    # similarity score, so the largest values correspond to the nearest items.\n",
    "    sim = 1 - cdist(item_embedding, item_embedding[nn_movie_id].reshape(1, -1), metric=metric)\n",
    "    return sim.squeeze().argsort()[-k:][::-1]"
   ]
  },
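  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of `find_similar_movies`, we can run it on synthetic embeddings instead of the trained ones (so the numbers below are illustrative only): since an item is at distance zero from itself, the query item should always come back as its own top neighbor.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy.spatial.distance import cdist\n",
    "\n",
    "def find_similar_movies(nn_movie_id, item_embedding, k=10, metric=\"euclidean\"):\n",
    "    sim = 1 - cdist(item_embedding, item_embedding[nn_movie_id].reshape(1, -1), metric=metric)\n",
    "    return sim.squeeze().argsort()[-k:][::-1]\n",
    "\n",
    "# 100 random items in 8 dimensions stand in for the trained movie embeddings\n",
    "rng = np.random.default_rng(0)\n",
    "emb = rng.normal(size=(100, 8))\n",
    "\n",
    "neighbors = find_similar_movies(5, emb, k=3)\n",
    "print(neighbors[0])  # the query item itself is its own nearest neighbor\n",
    "```"
   ]
  },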
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Load the mapping between original MovieLens movie IDs and contiguous embedding indices\n",
    "with open('./mappings.pickle', 'rb') as handle:\n",
    "    movies_mapping = pickle.load(handle)[\"items\"]\n",
    "\n",
    "nn_to_movies = movies_mapping\n",
    "movies_to_nn = {movie_id: i for i, movie_id in enumerate(movies_mapping)}\n",
    "\n",
    "movies = pd.read_csv(\"./data/ml-20m/movies.csv\", index_col=\"movieId\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query:  Toy Story (1995) Adventure|Animation|Children|Comedy|Fantasy\n",
      "Similar movies: \n",
      "110510 Série noire (1979) Film-Noir\n",
      "32361 Come and Get It (1936) Drama\n",
      "67999 Global Metal (2008) Documentary\n",
      "69356 Zulu Dawn (1979) Action|Drama|Thriller|War\n",
      "69381 Hitman, The (1991) Action|Crime|Thriller\n",
      "69442 Pekka ja Pätkä neekereinä (1960) Comedy\n",
      "69818 Franklyn (2008) Drama|Fantasy|Romance|Thriller\n",
      "70344 Cold Souls (2009) Comedy|Drama\n",
      "70495 Kill Buljo: The Movie (2007) Action|Comedy\n",
      "70864 Botched (2007) Comedy|Crime|Horror|Thriller\n",
      "=================================\n",
      "\n",
      "Query:  Jumanji (1995) Adventure|Children|Fantasy\n",
      "Similar movies: \n",
      "2 Jumanji (1995) Adventure|Children|Fantasy\n",
      "1333 Birds, The (1963) Horror|Thriller\n",
      "1240 Terminator, The (1984) Action|Sci-Fi|Thriller\n",
      "1089 Reservoir Dogs (1992) Crime|Mystery|Thriller\n",
      "593 Silence of the Lambs, The (1991) Crime|Horror|Thriller\n",
      "1387 Jaws (1975) Action|Horror\n",
      "112 Rumble in the Bronx (Hont faan kui) (1995) Action|Adventure|Comedy|Crime\n",
      "1198 Raiders of the Lost Ark (Indiana Jones and the Raiders of the Lost Ark) (1981) Action|Adventure\n",
      "1036 Die Hard (1988) Action|Crime|Thriller\n",
      "1246 Dead Poets Society (1989) Drama\n",
      "=================================\n",
      "\n",
      "Query:  Grumpier Old Men (1995) Comedy|Romance\n",
      "Similar movies: \n",
      "110510 Série noire (1979) Film-Noir\n",
      "32361 Come and Get It (1936) Drama\n",
      "67999 Global Metal (2008) Documentary\n",
      "69356 Zulu Dawn (1979) Action|Drama|Thriller|War\n",
      "69381 Hitman, The (1991) Action|Crime|Thriller\n",
      "69442 Pekka ja Pätkä neekereinä (1960) Comedy\n",
      "69818 Franklyn (2008) Drama|Fantasy|Romance|Thriller\n",
      "70344 Cold Souls (2009) Comedy|Drama\n",
      "70495 Kill Buljo: The Movie (2007) Action|Comedy\n",
      "70864 Botched (2007) Comedy|Crime|Horror|Thriller\n",
      "=================================\n",
      "\n",
      "Query:  Waiting to Exhale (1995) Comedy|Drama|Romance\n",
      "Similar movies: \n",
      "110510 Série noire (1979) Film-Noir\n",
      "32361 Come and Get It (1936) Drama\n",
      "67999 Global Metal (2008) Documentary\n",
      "69356 Zulu Dawn (1979) Action|Drama|Thriller|War\n",
      "69381 Hitman, The (1991) Action|Crime|Thriller\n",
      "69442 Pekka ja Pätkä neekereinä (1960) Comedy\n",
      "69818 Franklyn (2008) Drama|Fantasy|Romance|Thriller\n",
      "70344 Cold Souls (2009) Comedy|Drama\n",
      "70495 Kill Buljo: The Movie (2007) Action|Comedy\n",
      "70864 Botched (2007) Comedy|Crime|Horror|Thriller\n",
      "=================================\n",
      "\n",
      "Query:  Father of the Bride Part II (1995) Comedy\n",
      "Similar movies: \n",
      "110510 Série noire (1979) Film-Noir\n",
      "32361 Come and Get It (1936) Drama\n",
      "67999 Global Metal (2008) Documentary\n",
      "69356 Zulu Dawn (1979) Action|Drama|Thriller|War\n",
      "69381 Hitman, The (1991) Action|Crime|Thriller\n",
      "69442 Pekka ja Pätkä neekereinä (1960) Comedy\n",
      "69818 Franklyn (2008) Drama|Fantasy|Romance|Thriller\n",
      "70344 Cold Souls (2009) Comedy|Drama\n",
      "70495 Kill Buljo: The Movie (2007) Action|Comedy\n",
      "70864 Botched (2007) Comedy|Crime|Horror|Thriller\n",
      "=================================\n",
      "\n",
      "Query:  Heat (1995) Action|Crime|Thriller\n",
      "Similar movies: \n",
      "110510 Série noire (1979) Film-Noir\n",
      "32361 Come and Get It (1936) Drama\n",
      "67999 Global Metal (2008) Documentary\n",
      "69356 Zulu Dawn (1979) Action|Drama|Thriller|War\n",
      "69381 Hitman, The (1991) Action|Crime|Thriller\n",
      "69442 Pekka ja Pätkä neekereinä (1960) Comedy\n",
      "69818 Franklyn (2008) Drama|Fantasy|Romance|Thriller\n",
      "70344 Cold Souls (2009) Comedy|Drama\n",
      "70495 Kill Buljo: The Movie (2007) Action|Comedy\n",
      "70864 Botched (2007) Comedy|Crime|Horror|Thriller\n",
      "=================================\n",
      "\n",
      "Query:  Sabrina (1995) Comedy|Romance\n",
      "Similar movies: \n",
      "110510 Série noire (1979) Film-Noir\n",
      "32361 Come and Get It (1936) Drama\n",
      "67999 Global Metal (2008) Documentary\n",
      "69356 Zulu Dawn (1979) Action|Drama|Thriller|War\n",
      "69381 Hitman, The (1991) Action|Crime|Thriller\n",
      "69442 Pekka ja Pätkä neekereinä (1960) Comedy\n",
      "69818 Franklyn (2008) Drama|Fantasy|Romance|Thriller\n",
      "70344 Cold Souls (2009) Comedy|Drama\n",
      "70495 Kill Buljo: The Movie (2007) Action|Comedy\n",
      "70864 Botched (2007) Comedy|Crime|Horror|Thriller\n",
      "=================================\n",
      "\n",
      "Query:  Tom and Huck (1995) Adventure|Children\n",
      "Similar movies: \n",
      "110510 Série noire (1979) Film-Noir\n",
      "32361 Come and Get It (1936) Drama\n",
      "67999 Global Metal (2008) Documentary\n",
      "69356 Zulu Dawn (1979) Action|Drama|Thriller|War\n",
      "69381 Hitman, The (1991) Action|Crime|Thriller\n",
      "69442 Pekka ja Pätkä neekereinä (1960) Comedy\n",
      "69818 Franklyn (2008) Drama|Fantasy|Romance|Thriller\n",
      "70344 Cold Souls (2009) Comedy|Drama\n",
      "70495 Kill Buljo: The Movie (2007) Action|Comedy\n",
      "70864 Botched (2007) Comedy|Crime|Horror|Thriller\n",
      "=================================\n",
      "\n",
      "Query:  Sudden Death (1995) Action\n",
      "Similar movies: \n",
      "110510 Série noire (1979) Film-Noir\n",
      "32361 Come and Get It (1936) Drama\n",
      "67999 Global Metal (2008) Documentary\n",
      "69356 Zulu Dawn (1979) Action|Drama|Thriller|War\n",
      "69381 Hitman, The (1991) Action|Crime|Thriller\n",
      "69442 Pekka ja Pätkä neekereinä (1960) Comedy\n",
      "69818 Franklyn (2008) Drama|Fantasy|Romance|Thriller\n",
      "70344 Cold Souls (2009) Comedy|Drama\n",
      "70495 Kill Buljo: The Movie (2007) Action|Comedy\n",
      "70864 Botched (2007) Comedy|Crime|Horror|Thriller\n",
      "=================================\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for movie_ID in range(1, 10):\n",
    "    try:\n",
    "        print(\"Query: \", movies.loc[movie_ID][\"title\"], movies.loc[movie_ID][\"genres\"])\n",
    "\n",
    "        print(\"Similar movies: \")\n",
    "        similar_movies = find_similar_movies(movies_to_nn[movie_ID], item_embedding)\n",
    "\n",
    "        for i in similar_movies:\n",
    "            print(nn_to_movies[i], movies.loc[nn_to_movies[i]][\"title\"], movies.loc[nn_to_movies[i]][\"genres\"])\n",
    "        print(\"=================================\\n\")\n",
    "    except KeyError:\n",
    "        # Skip movie IDs that are missing from the mapping or the metadata table\n",
    "        continue"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
